{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chat-嬛嬛 v2.0 运行示例"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 一、对话集生成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "! bash ../../generation_dataset/generation.sh"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 二、模型微调"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "! bash ../../fine_tune/lora/train.sh"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 三、通过代码调用微调模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**ChatGLM**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/autodl-tmp/env/huanhuan-chat/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModel, AutoTokenizer\n",
    "import torch\n",
    "from peft import PeftModel\n",
    "\n",
    "# 本地底座模型参数路径\n",
    "MODEL_PATH = \"model path for ChatGLM\"\n",
    "# 微调参数路径\n",
    "LoRA_PATH = \"../../dataset/output\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n",
      "Loading checkpoint shards: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it]\n"
     ]
    }
   ],
   "source": [
    "# 加载底座模型\n",
    "model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map='auto')\n",
    "# 加载 tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)\n",
    "# 加载微调模型\n",
    "model = PeftModel.from_pretrained(model, LoRA_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_completion(prompt, temprature=0, max_length=4096):\n",
    "    \"\"\"Generate a reply to `prompt` with the LoRA fine-tuned ChatGLM model.\n",
    "\n",
    "    Args:\n",
    "        prompt: Text sent to the model.\n",
    "        temprature: Sampling temperature (the misspelled name is kept so\n",
    "            existing keyword calls keep working). 0 or a negative value selects\n",
    "            greedy decoding; a positive value enables sampling at that\n",
    "            temperature.\n",
    "        max_length: Upper bound on prompt tokens plus generated tokens.\n",
    "\n",
    "    Returns:\n",
    "        The newly generated text with the END marker removed and\n",
    "        surrounding whitespace stripped.\n",
    "    \"\"\"\n",
    "    # Sampling-only options must not be passed in greedy mode, and a zero\n",
    "    # temperature is not a valid sampling temperature.\n",
    "    gen_kwargs = {\"max_length\": max_length, \"do_sample\": temprature > 0}\n",
    "    if temprature > 0:\n",
    "        gen_kwargs[\"temperature\"] = temprature\n",
    "    with torch.no_grad():\n",
    "        ids = tokenizer.encode(prompt)\n",
    "        # Put the inputs on the model's device; CPU inputs with a CUDA model\n",
    "        # make generate() warn about a device mismatch.\n",
    "        input_ids = torch.LongTensor([ids]).to(model.device)\n",
    "        out = model.generate(input_ids=input_ids, **gen_kwargs)\n",
    "        # The output starts with the prompt tokens; decode only what follows\n",
    "        # instead of string-replacing the prompt out of the decoded text.\n",
    "        new_tokens = out[0][input_ids.shape[1]:]\n",
    "        out_text = tokenizer.decode(new_tokens)\n",
    "        return out_text.replace(\"\\nEND\", \"\").strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/autodl-tmp/env/huanhuan-chat/lib/python3.10/site-packages/transformers/generation/utils.py:1452: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "我是甄嬛，家父是大理寺少卿甄远道。\n"
     ]
    }
   ],
   "source": [
    "prompt = \"你是谁？\"\n",
    "print(get_completion(prompt))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**BaiChuan**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/autodl-tmp/env/huanhuan-chat/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import torch\n",
    "from peft import PeftModel\n",
    "\n",
    "# 本地底座模型参数路径\n",
    "MODEL_PATH = \"model path for BaiChuan\"\n",
    "# 微调参数路径\n",
    "LoRA_PATH = \"../../dataset/output-baichuan\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "# Release any model/tokenizer still bound from the ChatGLM section so that two\n",
    "# multi-billion-parameter models are never resident on the GPU at once when\n",
    "# the notebook is run top to bottom.\n",
    "model = tokenizer = None\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "# Load the base model directly in fp16 (avoids building a full fp32 copy\n",
    "# first), then move it to the GPU\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_PATH, trust_remote_code=True, torch_dtype=torch.float16\n",
    ").cuda()\n",
    "# Load the tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)\n",
    "# Attach the LoRA fine-tuned weights; keep the adapter in fp16 as well\n",
    "model = PeftModel.from_pretrained(model, LoRA_PATH).half()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_completion(prompt, temprature=0, max_length=4096):\n",
    "    \"\"\"Generate a reply to `prompt` with the LoRA fine-tuned BaiChuan model.\n",
    "\n",
    "    Args:\n",
    "        prompt: Text sent to the model.\n",
    "        temprature: Sampling temperature (the misspelled name is kept so\n",
    "            existing keyword calls keep working). 0 or a negative value selects\n",
    "            greedy decoding; a positive value enables sampling at that\n",
    "            temperature.\n",
    "        max_length: Upper bound on prompt tokens plus generated tokens.\n",
    "\n",
    "    Returns:\n",
    "        The newly generated text, without special tokens such as the\n",
    "        end-of-sequence marker and without surrounding whitespace.\n",
    "    \"\"\"\n",
    "    # Sampling-only options must not be passed in greedy mode, and a zero\n",
    "    # temperature is not a valid sampling temperature.\n",
    "    gen_kwargs = {\"max_length\": max_length, \"do_sample\": temprature > 0}\n",
    "    if temprature > 0:\n",
    "        gen_kwargs[\"temperature\"] = temprature\n",
    "    with torch.no_grad():\n",
    "        inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "        out = model.generate(**inputs, **gen_kwargs)\n",
    "        # The output starts with the prompt tokens; keep only what follows.\n",
    "        prompt_len = inputs.input_ids.shape[1]\n",
    "        answer = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)\n",
    "        return answer.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 我是甄嬛，家父是大理寺少卿甄远道。</s>\n"
     ]
    }
   ],
   "source": [
    "prompt = \"你是谁？\"\n",
    "print(get_completion(prompt))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "huanhuan-chat",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
